package com.gotrade.apirepeatrequest.aspect;

import com.gotrade.apirepeatrequest.annotation.NoRepeatSubmission;
import com.gotrade.apirepeatrequest.common.JacksonSerializer;
import com.gotrade.apirepeatrequest.model.Result;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ValueOperations;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

/**
 * @author jason.tang
 * @create 2019/4/8
 * @description
 */

@Slf4j
@Aspect
@Component
public class NoRepeatSubmissionAspect {

    @Autowired
    RedisTemplate<String, String> redisTemplate;

    /**
     * 环绕通知
     * @param pjp
     * @param ars
     * @return
     */
    @Around("execution(public * com.gotrade.apirepeatrequest.controller..*.*(..)) && @annotation(ars)")
    public Object doAround(ProceedingJoinPoint pjp, NoRepeatSubmission ars) {
        ValueOperations<String, String> opsForValue = redisTemplate.opsForValue();
        try {
            if (ars == null) {
                return pjp.proceed();
            }

            HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();

            String token = request.getHeader("Token");
            if (!checkToken(token)) {
                return Result.failure("Token无效");
            }
            String servletPath = request.getServletPath();
            String jsonString = this.getRequestParasJSONString(pjp);
            String sha1 = this.generateSHA1(jsonString);

            // key = token + servlet path
            String key = token + "-" + servletPath + "-" + sha1;

            log.info("\n{\n\tServlet Path: {}\n\tToken: {}\n\tJson String: {}\n\tSHA-1: {}\n\tResult Key: {} \n}", servletPath, token, jsonString, sha1, key);

            // 如果Redis中有这个key, 则url视为重复请求
            if (opsForValue.get(key) == null) {
                Object o = pjp.proceed();
                opsForValue.set(key, String.valueOf(0), 3, TimeUnit.SECONDS);
                return o;
            } else {
                return Result.failure("请勿重复请求");
            }
        } catch (Throwable e) {
            e.printStackTrace();
            return Result.failure("验证重复请求时出现未知异常");
        }
    }

    /**
     * 获取请求参数
     * @param pjp
     * @return
     */
    private String getRequestParasJSONString(ProceedingJoinPoint pjp) {
        String[] parameterNames = ((MethodSignature) pjp.getSignature()).getParameterNames();
        ConcurrentHashMap<String, String> args = null;

        if (Objects.nonNull(parameterNames)) {
            args = new ConcurrentHashMap<>(parameterNames.length);
            for (int i = 0; i < parameterNames.length; i++) {
                String value = pjp.getArgs()[i] != null ? pjp.getArgs()[i].toString() : "null";
                args.put(parameterNames[i], value);
            }
        }
        return JacksonSerializer.toJSONString(args);
    }

    private boolean checkToken(String token) {
        if (token == null || token.isEmpty()) {
            return false;
        }
        return true;
    }

    private String generateSHA1(String str){
        if (null == str || 0 == str.length()){
            return null;
        }
        char[] hexDigits = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                'a', 'b', 'c', 'd', 'e', 'f'};
        try {
            MessageDigest mdTemp = MessageDigest.getInstance("SHA1");
            mdTemp.update(str.getBytes(StandardCharsets.UTF_8));

            byte[] md = mdTemp.digest();
            int j = md.length;
            char[] buf = new char[j * 2];
            int k = 0;
            for (int i = 0; i < j; i++) {
                byte byte0 = md[i];
                buf[k++] = hexDigits[byte0 >>> 4 & 0xf];
                buf[k++] = hexDigits[byte0 & 0xf];
            }
            return new String(buf);
        } catch (NoSuchAlgorithmException e) {
            e.printStackTrace();
        }
        return null;
    }
}
